# Deep Kernel (DKL) code aligned with your Function Encoder setup
import matplotlib.pyplot as plt
import tqdm
import torch
from torch import nn
from torch.nn import functional as F

from torch.utils.data import DataLoader
from my_datasets.polynomial import PolynomialDataset

import time
import sys
import os
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))

# --- device ---
# Pick the first available backend: CUDA, then Apple MPS, otherwise CPU.
device = (
    "cuda" if torch.cuda.is_available()
    else "mps" if torch.backends.mps.is_available()
    else "cpu"
)

# Seed the global torch RNG (model initialisation below, and presumably
# the dataset's random function sampling).
torch.manual_seed(42)


# Load dataset; n_points / n_example_points presumably set the number of
# query and support (example) points per sampled function.
dataset = PolynomialDataset(n_points=100, n_example_points=20)
dataloader = DataLoader(dataset, batch_size=100)
# One persistent iterator; the training loop pulls a batch per epoch.
dataloader_iter = iter(dataloader)


class FeatureNet(nn.Module):
    """Fully connected ReLU MLP used as the deep-kernel feature extractor.

    Args:
        in_dim: size of the last dimension of the input.
        out_dim: size of the produced embedding.
        hidden: widths of the hidden layers, in order; each hidden
            Linear layer is followed by a ReLU.
    """

    def __init__(self, in_dim=1, out_dim=4, hidden=(8, 8, 8, 8)):
        super().__init__()
        layers = []
        d = in_dim
        for h in hidden:
            layers += [nn.Linear(d, h), nn.ReLU()]
            d = h
        # Output projection has no activation after it.
        layers += [nn.Linear(d, out_dim)]
        self.net = nn.Sequential(*layers)

    def forward(self, x):          # x: [*, in_dim]
        return self.net(x)         # -> [*, out_dim]


# Feature extractor: scalar input -> 4-dim embedding, one hidden layer of 64.
theta = FeatureNet(in_dim=1, out_dim=4, hidden=(64,))
theta = theta.to(device)


class DeepRBFKernel(nn.Module):
    """ARD RBF kernel evaluated on (already embedded) feature vectors.

    k(z_i, z_j) = sigma * exp(-0.5 * ||(z_i - z_j) / ls||^2)

    Hyperparameters are stored in log space so they stay positive after
    ``exp``. ``log_lam`` is not used by ``forward``; it holds the ridge
    term that callers add to the support Gram matrix.

    Args:
        dim: embedding dimension (one lengthscale per dimension).
        log_ls0: initial log-lengthscale, shared by all dimensions.
        log_sigma0: initial log output scale (multiplies the kernel directly).
        log_lam0: initial log ridge coefficient.
    """

    def __init__(self, dim, log_ls0=0.0, log_sigma0=0.0, log_lam0=0.0):
        super().__init__()
        self.log_ls = nn.Parameter(torch.full(
            (dim,), float(log_ls0)))   # ARD lengthscales
        self.log_sigma = nn.Parameter(
            torch.tensor(float(log_sigma0)))   # output scale
        self.log_lam = nn.Parameter(torch.tensor(float(log_lam0)))     # ridge

    def forward(self, zi, zj):
        """Return the kernel matrix between two sets of embeddings.

        Args:
            zi: tensor of shape [..., N, dim].
            zj: tensor of shape [..., M, dim]; leading (batch) dims must
                broadcast against those of ``zi``.

        Returns:
            Tensor of shape [..., N, M]. For 2-D inputs this is [N, M].
        """
        ls = self.log_ls.exp()
        sig = self.log_sigma.exp()
        zi_s = zi / ls
        zj_s = zj / ls
        # ||a - b||^2 = ||a||^2 + ||b||^2 - 2 a.b, avoiding an [N, M, dim]
        # difference tensor. Last-two-dim ops make this batch-friendly.
        zi2 = (zi_s**2).sum(-1, keepdim=True)        # [..., N, 1]
        zj2 = (zj_s**2).sum(-1).unsqueeze(-2)        # [..., 1, M]
        cross = zi_s @ zj_s.transpose(-2, -1)        # [..., N, M]
        # Cancellation in the expansion can produce tiny negatives.
        sqdist = (zi2 + zj2 - 2.0 * cross).clamp_min(0.0)
        return sig * torch.exp(-0.5 * sqdist)


# RBF kernel over theta's 4-dim embeddings: lengthscale and output scale
# start at e (log value 1.0), the ridge at e**-9.
kernel = DeepRBFKernel(
    dim=4,
    log_ls0=1.0,
    log_sigma0=1.0,
    log_lam0=-9,
)
kernel = kernel.to(device)


def dkl_batch_loss(X, y, example_X, example_y):
    """Mean kernel-ridge-regression MSE over a batch of functions.

    For each function b the support set (example_X[b], example_y[b]) is
    embedded with the global ``theta``; the system
    (K_SS + lam * I) alpha = y_S is solved with the global ``kernel``,
    and the query targets y[b] are predicted as K_QS @ alpha.

    Shapes, as implied by the ``.view(..., 1)`` calls: X is [B, P, 1] and
    example_X is [B, Ps, 1]; y / example_y are presumably [B, P, 1] /
    [B, Ps, 1] (TODO confirm against the dataset).

    Returns:
        Scalar tensor: the per-function MSEs summed and divided by B.
    """
    n_funcs, n_query, _ = X.shape
    _, n_support, _ = example_X.shape
    ridge = kernel.log_lam.exp()

    # Embed every scalar input with the shared feature net.
    support_emb = theta(example_X.view(n_funcs * n_support, 1)).view(
        n_funcs, n_support, -1)                                  # [B, Ps, D]
    query_emb = theta(X.view(n_funcs * n_query, 1)).view(
        n_funcs, n_query, -1)                                    # [B, P, D]

    # The identity is the same for every function; build it once.
    eye = torch.eye(n_support, device=support_emb.device, dtype=X.dtype)

    def _function_mse(b):
        """Kernel-ridge fit on function b's support set, MSE on its queries."""
        zs = support_emb[b]
        gram_ss = kernel(zs, zs)                  # [Ps, Ps]
        gram_qs = kernel(query_emb[b], zs)        # [P, Ps]
        weights = torch.linalg.solve(gram_ss + ridge * eye, example_y[b])
        return F.mse_loss(gram_qs @ weights, y[b])

    return sum(_function_mse(b) for b in range(n_funcs)) / n_funcs


# --- training ---
num_epochs = 500
lr = 1e-2
# Kernel hyperparameters are explicitly excluded from weight decay.
opt = torch.optim.Adam([
    {"params": theta.parameters()},
    {"params": kernel.parameters(), "weight_decay": 0.0},
], lr=lr)

start = time.perf_counter()
with tqdm.tqdm(range(num_epochs)) as tqdm_bar:
    for epoch in tqdm_bar:
        try:
            batch = next(dataloader_iter)
        except StopIteration:
            # A finite dataset is exhausted: start a new pass over it
            # instead of aborting training mid-run.
            dataloader_iter = iter(dataloader)
            batch = next(dataloader_iter)
        X, y, example_X, example_y = batch
        X = X.to(device)
        y = y.to(device)
        example_X = example_X.to(device)
        example_y = example_y.to(device)

        opt.zero_grad(set_to_none=True)
        loss = dkl_batch_loss(X, y, example_X, example_y)
        loss.backward()
        opt.step()
        tqdm_bar.set_postfix({"loss": f"{loss.item():.2e}"})
# Accelerator kernels run asynchronously; wait for the last step so the
# measured wall time includes it.
if device == "cuda":
    torch.cuda.synchronize()
elif device == "mps":
    torch.mps.synchronize()
end = time.perf_counter()
print(f"Wall time training: {end - start:.6f} s")

# --- evaluation on a single freshly sampled function ---
theta.eval()
kernel.eval()
with torch.no_grad():
    torch.manual_seed(123)
    val_loader = DataLoader(dataset, batch_size=1)
    X, y, example_X, example_y = next(iter(val_loader))

    X = X.to(device)
    y = y.to(device)
    example_X = example_X.to(device)
    example_y = example_y.to(device)

    # Sort query points by x so the curves below plot as clean lines.
    idx = torch.argsort(X, dim=1, descending=False)
    X_sorted = torch.gather(X, dim=1, index=idx)
    y_sorted = torch.gather(y, dim=1, index=idx)

    B, P, _ = X.shape
    _, Ps, _ = example_X.shape
    start = time.perf_counter()

    Z_S = theta(example_X.view(B*Ps, 1)).view(B, Ps, -1)  # [1,Ps,D]
    Z_all = theta(X_sorted.view(B*P, 1)).view(B, P, -1)     # [1,P,D]

    zS = Z_S[0]         # [Ps, D]
    zA = Z_all[0]       # [P, D]
    yS = example_y[0]   # [Ps, 1]
    yA = y_sorted[0]    # [P, 1]

    K_SS = kernel(zS, zS)             # [Ps, Ps]
    K_allS = kernel(zA, zS)           # [P, Ps]
    lam = kernel.log_lam.exp()

    # Kernel ridge regression: alpha = (K_SS + lam I)^-1 y_S.
    A = K_SS + lam * torch.eye(Ps, device=K_SS.device, dtype=K_SS.dtype)
    alpha = torch.linalg.solve(A, yS)
    y_hat_all = K_allS @ alpha          # [P,1]

    # Accelerator kernels are asynchronous; without a sync the timer would
    # only measure kernel launch, not the actual computation.
    if device == "cuda":
        torch.cuda.synchronize()
    elif device == "mps":
        torch.mps.synchronize()
    end = time.perf_counter()
    print(f"Wall time prediction: {end - start:.6f} s")

    mse_val = F.mse_loss(y_hat_all, yA)
    print("Validation MSE:", mse_val.item())

    # for plotting
    X_np = X_sorted.squeeze(0).cpu().numpy()
    y_np = y_sorted.squeeze(0).cpu().numpy()
    yhat_np = y_hat_all.cpu().numpy()   # already [P, 1]; no batch dim
    exX_np = example_X.squeeze(0).cpu().numpy()
    exy_np = example_y.squeeze(0).cpu().numpy()

    fig, ax = plt.subplots()
    ax.plot(X_np, y_np, label="True")
    ax.plot(X_np, yhat_np, label="Predicted")
    ax.scatter(exX_np, exy_np, label="Support (example)", color="red")
    ax.legend()
    # Save before show(): show() can close the figure, after which savefig
    # would write an empty image.
    fig.savefig("polynomial_dkl.png")
    plt.show()
